from typing import TypeVar

import numpy as np
import torch
from torch import Tensor

from pdbench.third_parties.esm.utils import residue_constants as RC
from pdbench.third_parties.esm.utils.structure.affine3d import Affine3D

ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)


def atom3_to_backbone_frames(bb_positions: torch.Tensor) -> Affine3D:
    N, CA, C = bb_positions.unbind(dim=-2)
    return Affine3D.from_graham_schmidt(C, CA, N)


def index_by_atom_name(
    atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
) -> ArrayOrTensor:
    squeeze = False
    if isinstance(atom_names, str):
        atom_names = [atom_names]
        squeeze = True
    indices = [RC.atom_order[atom_name] for atom_name in atom_names]
    dim = dim % atom37.ndim
    index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim))
    result = atom37[index]  # type: ignore
    if squeeze:
        result = result.squeeze(dim)
    return result


def get_protein_normalization_frame(coords: Tensor) -> Affine3D:
    """Given a set of coordinates for a protein, compute a single frame that can be used to normalize the coordinates.
    Specifically, we compute the average position of the N, CA, and C atoms use those 3 points to construct a frame
    using the Gram-Schmidt algorithm. The average CA position is used as the origin of the frame.

    Args:
        coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates

    Returns:
        Affine3D: tensor of Affine3D frame
    """
    bb_coords = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2)
    coord_mask = torch.all(torch.all(torch.isfinite(bb_coords), dim=-1), dim=-1)

    average_position_per_n_ca_c = bb_coords.masked_fill(
        ~coord_mask[..., None, None], 0
    ).sum(-3) / (coord_mask.sum(-1)[..., None, None] + 1e-8)
    frame = atom3_to_backbone_frames(average_position_per_n_ca_c.float())

    return frame


def apply_frame_to_coords(coords: Tensor, frame: Affine3D) -> Tensor:
    """Given a set of coordinates and a single frame, apply the frame to the coordinates.

    Args:
        coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates
        frame (Affine3D): Affine3D frame

    Returns:
        torch.FloatTensor: [L, 37, 3] tensor of transformed coordinates
    """
    coords_trans_rot = frame[..., None, None].invert().apply(coords)

    # only transform coordinates with frame that have a valid rotation
    valid_frame = frame.trans.norm(dim=-1) > 0

    is_inf = torch.isinf(coords)
    coords = coords_trans_rot.where(valid_frame[..., None, None, None], coords)
    coords.masked_fill_(is_inf, torch.inf)

    return coords


def normalize_coordinates(coords: Tensor) -> Tensor:
    return apply_frame_to_coords(coords, get_protein_normalization_frame(coords))
